{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from path_explain import utils\n",
    "utils.set_up_environment(visible_devices='1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets\n",
    "import numpy as np\n",
    "import scipy\n",
    "from transformers import *\n",
    "from plot.text import text_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data and Model Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GLUE task key; it is also passed to the feature-conversion helper further down.\n",
    "task = 'sst-2'\n",
    "\n",
    "# The number of classes is the length of the label list exposed by the task's GLUE processor.\n",
    "label_list = glue_processors[task]().get_labels()\n",
    "num_labels = len(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both the config and the model are loaded from this path (the current directory).\n",
    "checkpoint_dir = '.'\n",
    "\n",
    "config = DistilBertConfig.from_pretrained(checkpoint_dir, num_labels=num_labels)\n",
    "model = TFDistilBertForSequenceClassification.from_pretrained(checkpoint_dir, config=config)\n",
    "\n",
    "# The tokenizer is loaded by name rather than from the checkpoint directory.\n",
    "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load GLUE SST-2 together with its dataset metadata.\n",
    "# `data` is indexed by split name below ('train', 'validation').\n",
    "data, info = tensorflow_datasets.load('glue/sst2', with_info=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_length = 128  # forwarded as max_length to the feature-conversion helper\n",
    "eval_batch_size = 16  # batch size applied to the validation split\n",
    "\n",
    "train_dataset = glue_convert_examples_to_features(\n",
    "    data['train'], tokenizer, max_length=max_seq_length, task=task)\n",
    "\n",
    "# Only the validation split is batched in this cell.\n",
    "valid_dataset = glue_convert_examples_to_features(\n",
    "    data['validation'], tokenizer, max_length=max_seq_length, task=task\n",
    ").batch(eval_batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the whole batched validation set; the result is reduced with argmax in the accuracy cell.\n",
    "valid_pred = model.predict(x=valid_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Walk the batched validation set once, keeping the feature dicts and the labels (as NumPy arrays).\n",
    "valid_input, valid_labels = [], []\n",
    "for features, labels in valid_dataset:\n",
    "    valid_input.append(features)\n",
    "    valid_labels.append(labels.numpy())\n",
    "\n",
    "# Join the per-batch label arrays into one array along the batch axis.\n",
    "valid_labels_np = np.concatenate(valid_labels, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 0.9071\n",
      "Positive Sentiment Accuracy: 0.9054\n",
      "Negative Sentiment Accuracy: 0.9089\n"
     ]
    }
   ],
   "source": [
    "# Predicted class = index of the largest value along the last axis of the predictions.\n",
    "valid_pred_max = np.argmax(valid_pred, axis=-1)\n",
    "is_correct = valid_pred_max == valid_labels_np\n",
    "\n",
    "# Label 1 is reported as positive sentiment, label 0 as negative sentiment.\n",
    "positive_mask = valid_labels_np == 1\n",
    "negative_mask = valid_labels_np == 0\n",
    "\n",
    "# The mean of a boolean array is the fraction of True entries.\n",
    "accuracy = np.mean(is_correct)\n",
    "positive_accuracy = np.mean(is_correct[positive_mask])\n",
    "negative_accuracy = np.mean(is_correct[negative_mask])\n",
    "\n",
    "print('Validation Accuracy: {:.4f}'.format(accuracy))\n",
    "print('Positive Sentiment Accuracy: {:.4f}'.format(positive_accuracy))\n",
    "print('Negative Sentiment Accuracy: {:.4f}'.format(negative_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [\n",
    "    'This movie was bad',\n",
    "    'This movie was not bad',\n",
    "    'A movie',\n",
    "    'A bad movie',\n",
    "    'A bad, terrible movie',\n",
    "    'A bad, terrible, awful movie',\n",
    "    'A bad, terrible, awful, horrible movie',\n",
    "]\n",
    "\n",
    "# Encode each sentence separately with padding up to max_length=128,\n",
    "# then join the per-sentence tensors along axis 0 into a single batch.\n",
    "encodings = [\n",
    "    tokenizer.encode_plus(sentence,\n",
    "                          add_special_tokens=True,\n",
    "                          return_tensors='tf',\n",
    "                          pad_to_max_length=True,\n",
    "                          max_length=128)\n",
    "    for sentence in sentences\n",
    "]\n",
    "\n",
    "feature_names = ('input_ids', 'token_type_ids', 'attention_mask')\n",
    "encoded_sentences = {\n",
    "    name: tf.concat([encoding[name] for encoding in encodings], axis=0)\n",
    "    for name in feature_names\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(7, 2), dtype=float32, numpy=\n",
       "array([[ 3.6690195, -2.6701076],\n",
       "       [-2.0752487,  1.9163918],\n",
       "       [-2.6472576,  2.4648838],\n",
       "       [ 3.611107 , -2.6552174],\n",
       "       [ 3.6766212, -2.7141526],\n",
       "       [ 3.7107847, -2.716646 ],\n",
       "       [ 3.7191744, -2.7317934]], dtype=float32)>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element 0 of the model output is displayed: one row per sentence, one column per class.\n",
    "sentence_outputs = model(encoded_sentences)\n",
    "sentence_outputs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({'input_ids': <tf.Tensor: shape=(16, 128), dtype=int32, numpy=\n",
      "array([[  101,  2009, 23283, ...,     0,     0,     0],\n",
      "       [  101,  2205,  2172, ...,     0,     0,     0],\n",
      "       [  101,  1037,  2307, ...,     0,     0,     0],\n",
      "       ...,\n",
      "       [  101,  2004,  1996, ...,     0,     0,     0],\n",
      "       [  101,  1996,  2516, ...,     0,     0,     0],\n",
      "       [  101,  9327,  1012, ...,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(16, 128), dtype=int32, numpy=\n",
      "array([[1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       ...,\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'token_type_ids': <tf.Tensor: shape=(16, 128), dtype=int32, numpy=\n",
      "array([[0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       ...,\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}, <tf.Tensor: shape=(16,), dtype=int64, numpy=array([1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1])>)\n"
     ]
    }
   ],
   "source": [
    "# Fetch the first validation batch explicitly instead of leaking a loop variable;\n",
    "# `item` is a (features, labels) pair that the following cells reuse.\n",
    "item = next(iter(valid_dataset.take(1)))\n",
    "print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(16, 2), dtype=float32, numpy=\n",
       "array([[-3.9101574 ,  3.8021853 ],\n",
       "       [ 3.8041704 , -2.7883568 ],\n",
       "       [ 1.9888983 , -1.4849629 ],\n",
       "       [ 3.4512417 , -2.530995  ],\n",
       "       [-2.443832  ,  2.26392   ],\n",
       "       [-4.1048136 ,  4.016287  ],\n",
       "       [-1.5832212 ,  1.3562965 ],\n",
       "       [ 2.9390903 , -2.1744635 ],\n",
       "       [-0.35801238,  0.43716034],\n",
       "       [-4.2704735 ,  4.176792  ],\n",
       "       [-3.441905  ,  3.28892   ],\n",
       "       [-2.6926115 ,  2.5206203 ],\n",
       "       [ 2.647887  , -1.9844501 ],\n",
       "       [-4.2204533 ,  4.191584  ],\n",
       "       [ 0.8310715 , -0.6177254 ],\n",
       "       [-4.196546  ,  4.136347  ]], dtype=float32)>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `item` is a (features, labels) pair; only the feature dict is passed to the model.\n",
    "batch_features, _batch_labels = item\n",
    "model(batch_features)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First step of the manual forward pass: run the batch's token ids\n",
    "# through the embeddings layer of the DistilBERT backbone.\n",
    "batch_features = item[0]\n",
    "input_ids = batch_features['input_ids']\n",
    "embeddings = model.distilbert.embeddings(input_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask is 1.0 wherever the token id differs from the tokenizer's padding id, 0.0 elsewhere.\n",
    "attention_mask = tf.cast(tf.not_equal(input_ids, tokenizer.pad_token_id), dtype=tf.float32)\n",
    "\n",
    "# One unset (None) head-mask entry per hidden layer of the backbone.\n",
    "head_mask = [None] * model.distilbert.num_hidden_layers\n",
    "\n",
    "# Element 0 of the transformer output is kept; the classifier head consumes it in the next cell.\n",
    "transformer_output = model.distilbert.transformer(\n",
    "    [embeddings, attention_mask, head_mask], training=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classification head applied by hand: take the first sequence position of the\n",
    "# transformer output, then apply the pre-classifier and classifier layers in turn.\n",
    "first_position_state = transformer_output[:, 0]\n",
    "pooled_output = model.pre_classifier(first_position_state)\n",
    "logits = model.classifier(pooled_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(16, 2), dtype=float32, numpy=\n",
       "array([[-3.9101574 ,  3.8021853 ],\n",
       "       [ 3.8041704 , -2.7883568 ],\n",
       "       [ 1.9888983 , -1.4849629 ],\n",
       "       [ 3.4512417 , -2.530995  ],\n",
       "       [-2.443832  ,  2.26392   ],\n",
       "       [-4.1048136 ,  4.016287  ],\n",
       "       [-1.5832212 ,  1.3562965 ],\n",
       "       [ 2.9390903 , -2.1744635 ],\n",
       "       [-0.35801238,  0.43716034],\n",
       "       [-4.2704735 ,  4.176792  ],\n",
       "       [-3.441905  ,  3.28892   ],\n",
       "       [-2.6926115 ,  2.5206203 ],\n",
       "       [ 2.647887  , -1.9844501 ],\n",
       "       [-4.2204533 ,  4.191584  ],\n",
       "       [ 0.8310715 , -0.6177254 ],\n",
       "       [-4.196546  ,  4.136347  ]], dtype=float32)>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the manually computed logits should match what the model\n",
    "# returns for the same batch, so fail loudly instead of comparing by eye.\n",
    "np.testing.assert_allclose(logits.numpy(), model(item[0])[0].numpy(), rtol=1e-5, atol=1e-5)\n",
    "logits"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
